{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6ca307cd",
   "metadata": {},
   "source": [
    "# Mixed Circular and Normal Neural Spline Flow"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b06d613c",
   "metadata": {},
   "source": [
    "This is a Neural Spline Flow model which has circularand unbounded random variables combined in one random vector."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e36d1a8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import packages\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "import normflows as nf\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76ddb9e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up target\n",
    "class Target:\n",
    "    def __init__(self, ndim, ind_circ):\n",
    "        self.ndim = ndim\n",
    "        self.ind_circ = ind_circ  \n",
    "    \n",
    "    def sample(self, n):\n",
    "        s = torch.randn(n, self.ndim)\n",
    "        c = torch.rand(n, self.ndim) > 0.6\n",
    "        s = c * (0.3 * s - 0.5) + (1 - 1. * c) * (s + 1.3)\n",
    "        u = torch.rand(n, len(self.ind_circ))\n",
    "        s_ = torch.acos(2 * u - 1)\n",
    "        c = torch.rand(n, len(self.ind_circ)) > 0.3\n",
    "        s_[c] = -s_[c]\n",
    "        s[:, self.ind_circ] = (s_ + 1) % (2 * np.pi) - np.pi\n",
    "        return s\n",
    "    \n",
    "# Visualize target\n",
    "target = Target(2, [1])\n",
    "s = target.sample(1000000)\n",
    "plt.hist(s[:, 0].data.numpy(), bins=200)\n",
    "plt.show()\n",
    "plt.hist(s[:, 1].data.numpy(), bins=200)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d31852ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "base = nf.distributions.UniformGaussian(2, [1], torch.tensor([1., 2 * np.pi]))\n",
    "\n",
    "# Visualize base\n",
    "s = base.sample(1000000)\n",
    "plt.hist(s[:, 0].data.numpy(), bins=200)\n",
    "plt.show()\n",
    "plt.hist(s[:, 1].data.numpy(), bins=200)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86ebe83e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create normalizing flow\n",
    "K = 20\n",
    "\n",
    "flow_layers = []\n",
    "for i in range(K):\n",
    "    flow_layers += [nf.flows.CircularAutoregressiveRationalQuadraticSpline(2, 1, 128, [1], \n",
    "                                                                           tail_bound=torch.tensor([5., np.pi]),\n",
    "                                                                           permute_mask=True)]\n",
    "\n",
    "model = nf.NormalizingFlow(base, flow_layers)\n",
    "\n",
    "# Move model on GPU if available\n",
    "enable_cuda = True\n",
    "device = torch.device('cuda' if torch.cuda.is_available() and enable_cuda else 'cpu')\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3acd581d",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    s, _ = model.sample(50000)\n",
    "model.train()\n",
    "plt.hist(s[:, 0].cpu().data.numpy(), bins=100)\n",
    "plt.show()\n",
    "plt.hist(s[:, 1].cpu().data.numpy(), bins=100)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a8cfa4f",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Train model\n",
    "max_iter = 20000\n",
    "num_samples = 2 ** 10\n",
    "show_iter = 5000\n",
    "\n",
    "\n",
    "loss_hist = np.array([])\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)\n",
    "for it in tqdm(range(max_iter)):\n",
    "    optimizer.zero_grad()\n",
    "    \n",
    "    # Get training samples\n",
    "    x = target.sample(num_samples)\n",
    "    \n",
    "    # Compute loss\n",
    "    loss = model.forward_kld(x.to(device))\n",
    "    \n",
    "    # Do backprop and optimizer step\n",
    "    if ~(torch.isnan(loss) | torch.isinf(loss)):\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    \n",
    "    # Log loss\n",
    "    loss_hist = np.append(loss_hist, loss.to('cpu').data.numpy())\n",
    "    \n",
    "    # Plot learned density\n",
    "    if (it + 1) % show_iter == 0:\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            s, _ = model.sample(50000)\n",
    "        model.train()\n",
    "        plt.hist(s[:, 0].cpu().data.numpy(), bins=100)\n",
    "        plt.show()\n",
    "        plt.hist((s[:, 1].cpu().data.numpy() - 1) % (2 * np.pi), bins=100)\n",
    "        plt.show()\n",
    "\n",
    "# Plot loss\n",
    "plt.figure(figsize=(10, 10))\n",
    "plt.plot(loss_hist, label='loss')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
